"""

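CN with EnKF: roll out the CN model with an ensemble Kalman filter and plot
evaluation graphs (RMSE, PCC, relative accuracy) for different model-noise (Q)
and observation-noise (R) levels.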


"""
from numpy import linalg as LA
from scipy import stats
import torch, os, cv2
import numpy as np
from matplotlib import pyplot as plt, patches
from torch_geometric.data import Data
from torch import nn
import pandas as pd

class Encoder(nn.Module):
    """Four-layer MLP with Mish activations between the linear layers.

    :param input_size: width of each input row
    :param output_size: width of each output row (also stored as ``self.output_size``)
    :param hidden_size: width of the three hidden layers
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.output_size = output_size

        widths = [input_size, hidden_size, hidden_size, hidden_size, output_size]
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx:
                # activation sits between consecutive linear layers, never after the last one
                modules.append(nn.Mish())
            modules.append(nn.Linear(fan_in, fan_out))
        # A single Sequential named ``layers`` keeps state_dict keys (layers.0, layers.2, ...)
        # compatible with previously saved checkpoints.
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must equal ``input_size``)."""
        return self.layers(x)


class Decoder(nn.Module):
    """Two-layer MLP (Mish activation) producing a single output per row.

    :param input_size: width of each input row
    :param hidden_size: width of the hidden layer
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        hidden = nn.Linear(input_size, hidden_size)
        head = nn.Linear(hidden_size, 1)  # one output per row (speedX)
        # Keep the ``layers`` Sequential layout so checkpoint keys stay layers.0 / layers.2.
        self.layers = nn.Sequential(hidden, nn.Mish(), head)

    def forward(self, x):
        """Map ``x`` of width ``input_size`` to a column of width 1."""
        out = self.layers(x)
        return out


class CN(nn.Module):
    """Message-passing model (interaction-network style) that predicts one value per node.

    Edges are described by incidence matrices ``sender_relations`` and
    ``receiver_relations`` of shape (n_objects, n_relations); ``relations.T @ objects``
    gathers per-edge node features. The encoder maps each edge's
    [receiver features, sender features, edge features] to an effect vector, the
    effects are summed onto their receiving nodes (``receiver_relations @ effects``),
    and the decoder maps [node features, aggregated effects] to one value per node.
    """

    def __init__(self):
        super(CN, self).__init__()
        object_dim = 4  # node features
        relation_dim = 1  # edge features
        effect_dim = 50
        x_external_dim = 0
        self.encoder_model = Encoder(2 * object_dim + relation_dim, effect_dim, 150)
        self.decoder_model = Decoder(object_dim + effect_dim + x_external_dim, 100)

    def forward(self, objects, sender_relations, receiver_relations, relation_info):
        """Predict one value per node.

        :param objects: node features, (n_objects, 4)
        :param sender_relations: (n_objects, n_relations) sender incidence matrix
        :param receiver_relations: (n_objects, n_relations) receiver incidence matrix
        :param relation_info: edge features, (n_relations, 1)
        :return: tensor of shape (n_objects, 1)
        """
        # Cast once so every concatenation below is float32. Previously only the matmul
        # inputs were cast, so a float64 ``objects`` reached the float32 decoder through
        # the final torch.cat and raised a dtype mismatch.
        objects = objects.float()
        senders = torch.matmul(torch.t(sender_relations.float()), objects)  # (n_relations, 4)
        receivers = torch.matmul(torch.t(receiver_relations.float()), objects)  # (n_relations, 4)
        m = torch.cat((receivers, senders, relation_info.float()), 1)
        effects = self.encoder_model(m)  # (n_relations, effect_dim)
        effect_receivers = torch.matmul(receiver_relations.float(), effects)  # (n_objects, effect_dim)

        aggregation_result = torch.cat((objects, effect_receivers), 1)  # (n_objects, 4 + effect_dim)
        return self.decoder_model(aggregation_result)

def generate_position_graph(x_list, node_size, total_time_step, folder_path, x_min, x_max):
    """Save a PNG snapshot of the floe positions every 10 time steps.

    Each floe is drawn as a red-edged circle of radius 1 centred at ``(x, 0)``; the
    snapshot for step ``t`` is written to ``{folder_path}/{t}.png``.

    :param x_list: positions indexed as ``x_list[node][t]``
    :param node_size: number of floes to draw
    :param total_time_step: number of time steps; snapshots are taken at 0, 10, 20, ...
    :param folder_path: output directory, created if it does not exist
    :param x_min: currently unused (the x axis is fixed to [0, 100])
    :param x_max: currently unused (the x axis is fixed to [0, 100])
    :return: None
    """
    for step in range(0, total_time_step, 10):
        fig, ax = plt.subplots(num=1, clear=True)  # reuse figure 1 instead of opening new ones
        fig.set_size_inches(100, 4)  # (width, height)
        ax.set_xlim(0, 100)
        ax.set_ylim(-2, 2)
        ax.set_title(f"One dimensional simulation at {step} step")

        floes = (patches.Circle((x_list[node][step], 0), radius=1, edgecolor='red') for node in range(node_size))
        for floe in floes:
            ax.add_patch(floe)

        os.makedirs(folder_path, exist_ok=True)
        plt.savefig(f'{folder_path}/{step}.png')

def generate_video(image_folder, num, total_time, step):
    """Assemble ``{t}.png`` frames from ``image_folder`` into an AVI video.

    Frames ``0.png, step.png, 2*step.png, ...`` (below ``total_time``) are read,
    downscaled to one tenth of the first frame's width and height, and written to
    ``{image_folder}/1D_simulation_{num}.avi`` at 100 frames per second.

    :param image_folder: folder holding the frames; the video is written here too
    :param num: label for the simulation, used in the video file name
    :param total_time: the total time steps (exclusive upper bound of frame numbers)
    :param step: jump between consecutive frame numbers
    :return: None
    :raises ValueError: if no frame numbers are selected
    :raises FileNotFoundError: if a frame cannot be read
    """
    video_name = f'{image_folder}/1D_simulation_{num}.avi'

    images = [f'{i}.png' for i in range(0, total_time, step)]
    if not images:
        raise ValueError(f'no frames selected by range(0, {total_time}, {step})')

    first_path = os.path.join(image_folder, images[0])
    frame = cv2.imread(first_path)
    if frame is None:  # cv2.imread reports unreadable/missing files by returning None
        raise FileNotFoundError(f'could not read image {first_path}')
    height, width = frame.shape[:2]
    # Target size in cv2's (width, height) order, computed once for every frame.
    down_points = (int(width / 10), int(height / 10))

    # fourcc 0 as before; 100 means rendering 100 images per second
    video = cv2.VideoWriter(video_name, 0, 100, down_points)
    try:
        for image in images:
            path = os.path.join(image_folder, image)
            img = cv2.imread(path)
            if img is None:
                raise FileNotFoundError(f'could not read image {path}')
            video.write(cv2.resize(img, down_points, interpolation=cv2.INTER_LINEAR))
    finally:
        # Always finalize the file, even if a frame fails to load.
        video.release()


def roll_out_data_2steps_v(data_path, split, total_simulation=None, total_time_step=10000, node_size=12):
    """Load per-simulation position and velocity arrays from ``{data_path}/{split}_position.dat``.

    The file is memory-mapped as float32 with shape
    ``(total_simulation, total_time_step, node_size, 3)``. Channel 0 of the last axis
    is returned as position and channel 1 as velocity; channel 2 is not used here.

    :param data_path: folder containing the ``.dat`` file
    :param split: split name used in the file name (e.g. 'train', 'valid')
    :param total_simulation: number of simulations in the file; defaults to 1000 for
        'train' and 100 for any other split
    :param total_time_step: time steps per simulation (default 10000)
    :param node_size: number of nodes per time step (default 12)
    :return: list with one dict per simulation, holding "position" and "velocity"
        arrays of shape (total_time_step, node_size)
    """
    if total_simulation is None:
        total_simulation = 1000 if split == 'train' else 100

    # mode="c" is copy-on-write: in-memory edits to the arrays never reach the file.
    data_fp = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="c",
                        shape=(total_simulation, total_time_step, node_size, 3))

    return [
        {"position": data_fp[simulation, :, :, 0], "velocity": data_fp[simulation, :, :, 1]}
        for simulation in range(total_simulation)
    ]



def _build_chain_relations(n_objects):
    """Build sender/receiver incidence matrices for a 1-D chain of ``n_objects`` nodes.

    Each neighbouring pair (k, k + 1) gets two directed edges: edge ``2k`` goes from
    node k to node k + 1, and edge ``2k + 1`` goes from node k + 1 to node k.

    :param n_objects: number of nodes in the chain
    :return: (sender_relations, receiver_relations), float arrays of shape
        (n_objects, 2 * (n_objects - 1)) holding 1.0 at [node, edge]
    """
    n_relations = 2 * (n_objects - 1)
    sender_relations = np.zeros((n_objects, n_relations), dtype=float)
    receiver_relations = np.zeros((n_objects, n_relations), dtype=float)
    for k in range(n_objects - 1):
        sender_relations[k, 2 * k] = 1.0  # edge k -> k+1
        receiver_relations[k + 1, 2 * k] = 1.0
        sender_relations[k + 1, 2 * k + 1] = 1.0  # edge k+1 -> k
        receiver_relations[k, 2 * k + 1] = 1.0
    return sender_relations, receiver_relations


def ENKF(model, data, model_noise, observation_noise, observation_freq=100, ensemble_number=100,
         dt=1e-4, total_time=None):
    """Roll the CN model out for an ensemble and assimilate observed positions with an EnKF.

    Every member starts from the first two ground-truth rows. At each step the model
    predicts node velocities from the last two positions and the last velocity. The
    new position is ``x_{t-1} + v_t * dt`` plus a draw from N(0, Q). The first and last
    nodes are boundaries: their velocity is forced to 0 and their positions to -1 and 101.
    Every ``observation_freq`` steps, the interior state [positions, velocities] of all
    members is updated with a stochastic (perturbed-observation) EnKF analysis. The
    ground-truth interior positions serve as the observations.

    :param model: CN-style model called as model(node_features, senders, receivers, edge_features)
    :param data: dict with "position" and "velocity" arrays of shape (time, num_nodes)
    :param model_noise: diagonal value of the model-noise covariance Q (all nodes)
    :param observation_noise: diagonal value of the observation-noise covariance R (interior nodes)
    :param observation_freq: assimilate every this many steps (default 100)
    :param ensemble_number: number of ensemble members (default 100; must be >= 2)
    :param dt: time step used to integrate velocity into position (default 1e-4)
    :param total_time: number of time steps to produce; defaults to len(data["position"])
    :return: position trajectories, array of shape (ensemble_number, total_time, num_nodes)
    :raises ValueError: if ensemble_number < 2 (the sample covariance divides by N - 1)
    """
    if ensemble_number < 2:
        raise ValueError(f'ensemble_number must be at least 2, got {ensemble_number}')

    positions = np.asarray(data["position"])
    velocities = np.asarray(data["velocity"])
    if total_time is None:
        total_time = positions.shape[0]
    num_nodes = positions.shape[1]
    num_interior = num_nodes - 2  # boundary nodes are pinned and excluded from the filter state
    device = next(model.parameters()).device

    Q = np.diag([model_noise] * num_nodes)
    R = np.diag([observation_noise] * num_interior)
    # Filter state is [interior positions | interior velocities]; only positions are observed.
    Hj = np.eye(num_interior, 2 * num_interior)

    # Preallocate full histories. The previous torch.cat per step copied the whole
    # history each time, which is quadratic in total_time.
    traj = np.zeros((ensemble_number, total_time, num_nodes))
    velocity_traj = np.zeros((ensemble_number, total_time, num_nodes))
    traj[:, :2, :] = positions[:2, :]
    velocity_traj[:, :2, :] = velocities[:2, :]

    sender_np, receiver_np = _build_chain_relations(num_nodes)
    # Follow the model's device rather than hard-coding .cuda() (which broke CPU-only runs).
    sender_relations = torch.from_numpy(sender_np).to(device)
    receiver_relations = torch.from_numpy(receiver_np).to(device)
    radius = np.ones((num_nodes, 1))

    with torch.no_grad():
        for t in range(2, total_time):
            for ensemble in range(ensemble_number):
                x_prev2 = traj[ensemble, t - 2, :].reshape(num_nodes, 1)
                x_prev = traj[ensemble, t - 1, :].reshape(num_nodes, 1)
                v_prev = velocity_traj[ensemble, t - 1, :].reshape(num_nodes, 1)
                node_features = np.concatenate((x_prev2, x_prev, v_prev, radius), axis=1)

                # Edge 2k carries x_{k+1} - x_k, edge 2k+1 carries its negation.
                gap = x_prev[1:, 0] - x_prev[:-1, 0]
                edge_features = np.stack((gap, -gap), axis=1).reshape(-1, 1)

                new_velocity = model(
                    torch.from_numpy(node_features).float().to(device),
                    sender_relations,
                    receiver_relations,
                    torch.from_numpy(edge_features).float().to(device),
                )
                new_velocity = new_velocity.detach().cpu().numpy().reshape(num_nodes)
                new_velocity[0] = 0
                new_velocity[-1] = 0

                qi = np.random.multivariate_normal([0] * num_nodes, Q, 1)  # model error, mean 0
                new_position = x_prev[:, 0] + new_velocity * dt + qi[0]
                new_position[0] = -1
                new_position[-1] = 101
                traj[ensemble, t, :] = new_position
                velocity_traj[ensemble, t, :] = new_velocity

            if t % observation_freq == 0:
                # (ensemble_number, 2 * num_interior) stacked ensemble state.
                state = np.concatenate((traj[:, t, 1:-1], velocity_traj[:, t, 1:-1]), axis=1)
                U = (state - state.mean(axis=0)).T  # state anomalies
                V = Hj @ U  # anomalies in observation space
                innovation_cov = V @ V.T / (ensemble_number - 1) + R
                K = U @ V.T @ np.linalg.inv(innovation_cov) / (ensemble_number - 1)  # Kalman gain

                # Perturbed observations: each member sees the truth plus its own N(0, R) draw.
                ri = np.random.multivariate_normal([0] * num_interior, R, ensemble_number)
                yi = (positions[t, 1:-1] + ri).T

                updated_X = (state.T + K @ (yi - Hj @ state.T)).T
                traj[:, t, 1:-1] = updated_X[:, :num_interior]
                velocity_traj[:, t, 1:-1] = updated_X[:, num_interior:]

    return traj

def rmse(predictions, targets):
    """Return the root-mean-square error between two array-like inputs of matching shape.

    :param predictions: predicted values (numpy array)
    :param targets: reference values (numpy array)
    :return: sqrt(mean((predictions - targets) ** 2))
    """
    residual = predictions - targets
    squared = residual ** 2
    return np.sqrt(squared.mean())

def pattern_correlation(x, y):
    """Return the Pearson correlation coefficient between ``x`` and ``y``.

    The coefficient measures the linear relationship between the two datasets and
    lies in [-1, 1]: 0 means no linear correlation, and +1 / -1 mean an exact
    increasing / decreasing linear relationship.

    :param x: 1-D sequence
    :param y: 1-D sequence of the same length
    :return: Pearson correlation coefficient
    """
    # Tuple unpacking works both on SciPy < 1.9 (plain tuple) and on newer
    # PearsonRResult objects; the ``.statistic`` attribute only exists from 1.9.
    statistic, _pvalue = stats.pearsonr(x, y)
    return statistic
def average_pattern_correlation(predict_traj, true_traj):
    """Per-node Pearson correlation over time, and its mean across nodes.

    :param predict_traj: (total_time_steps, num_nodes)
    :param true_traj: (total_time_steps, num_nodes)
    :return: (mean coefficient, array of shape (num_nodes,) with one coefficient per node)
    """
    node_count = predict_traj.shape[1]
    coefficients = np.array(
        [pattern_correlation(predict_traj[:, col], true_traj[:, col]) for col in range(node_count)],
        dtype=float,
    )
    return coefficients.mean(), coefficients

def average_RMSE(predict_traj, true_traj, total_time_steps):
    """RMSE across nodes at each time step, and its mean over the first ``total_time_steps`` steps.

    :param predict_traj: (total_time_steps, num_nodes)
    :param true_traj: (total_time_steps, num_nodes)
    :param total_time_steps: number of leading time steps to evaluate
    :return: (mean RMSE, array of shape (total_time_steps,) with the per-step RMSE)
    """
    per_step = np.array(
        [rmse(predict_traj[step], true_traj[step]) for step in range(total_time_steps)],
        dtype=float,
    )
    return per_step.mean(), per_step
def relative_accuracy(predict_traj, true_traj, total_time_steps):
    """Per-node relative accuracy ``1 - |predicted - truth| / truth`` at every time step.

    Used to draw one relative-accuracy line per node against time.

    :param predict_traj: (total_time_steps, num_nodes)
    :param true_traj: (total_time_steps, num_nodes)
    :param total_time_steps: number of leading time steps to evaluate
    :return: array of shape (total_time_steps, num_nodes)
    :raises IndexError: if either trajectory has fewer than ``total_time_steps`` rows
    """
    predicted = np.asarray(predict_traj, dtype=float)
    truth = np.asarray(true_traj, dtype=float)
    if len(predicted) < total_time_steps or len(truth) < total_time_steps:
        raise IndexError(f'trajectories shorter than total_time_steps={total_time_steps}')

    predicted = predicted[:total_time_steps]
    truth = truth[:total_time_steps]
    # Vectorized per (time, node). The old loop relied on an undefined ``num_nodes``
    # global and wrote each value into the whole row, so only the last node survived.
    return 1 - np.abs(predicted - truth) / truth

def average_accuracy(predict_traj, true_traj, total_time_steps):
    """Relative accuracy ``1 - |predicted - truth| / truth`` averaged over nodes at each time step.

    :param predict_traj: (total_time_steps, num_nodes)
    :param true_traj: (total_time_steps, num_nodes)
    :param total_time_steps: number of leading time steps to evaluate
    :return: array of shape (total_time_steps,)
    :raises IndexError: if either trajectory has fewer than ``total_time_steps`` rows
    """
    predicted = np.asarray(predict_traj, dtype=float)
    truth = np.asarray(true_traj, dtype=float)
    if len(predicted) < total_time_steps or len(truth) < total_time_steps:
        raise IndexError(f'trajectories shorter than total_time_steps={total_time_steps}')

    predicted = predicted[:total_time_steps]
    truth = truth[:total_time_steps]
    # Node count comes from the data itself; the old loop read an undefined ``num_nodes`` global.
    return (1 - np.abs(predicted - truth) / truth).mean(axis=1)


def _save_figure(folder_path, file_name):
    """Save the current matplotlib figure to ``{folder_path}/{file_name}``, creating the folder if needed."""
    os.makedirs(folder_path, exist_ok=True)
    plt.savefig(f'{folder_path}/{file_name}')


if __name__ == '__main__':
    # ---------------------------------------------------------------- configuration
    # The original script held unfilled placeholders: an f-string containing
    # `{the optimal checkpoint number}` (a SyntaxError that made the file unimportable)
    # and the undefined names `data_IN`, `one_ensemble`, `ensemble_chosen_index`.
    # They are explicit settings now; fill them in before running.
    OPTIMAL_CHECKPOINT = 0  # TODO: optimal checkpoint number reported by compare_CN_30.py
    data_IN = 'data'  # TODO: folder (relative to ./) containing {split}_position.dat
    one_ensemble = False  # True: evaluate a single ensemble member instead of the ensemble mean
    ensemble_chosen_index = 0  # member evaluated when one_ensemble is True

    split_list = ['valid']
    model_noise_list = [0.1, 1, 2]
    observer_noise_list = [0.1, 1, 2]
    DATASET_NAME_org = "CN_10"
    model_path = f"./models/{DATASET_NAME_org}"
    MODEL_NAME = f'{DATASET_NAME_org}_{OPTIMAL_CHECKPOINT}'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    total_time_steps = 10000
    PLOT_GRAPH = True

    simulator = CN().to(device)
    # NOTE(review): the original built MODEL_NAME but never loaded it, so the filter ran an
    # untrained network. This assumes the training script saved a state_dict at
    # {model_path}/{MODEL_NAME}; adjust path/format to match how checkpoints are written.
    simulator.load_state_dict(torch.load(f'{model_path}/{MODEL_NAME}', map_location=device))
    simulator.eval()

    for split in split_list:
        rollout_dataset = roll_out_data_2steps_v(f'./{data_IN}', split)  # ground truth

        # Without knowing the position: the filter only sees every observation_freq-th step.
        for i_dataset, rollout_data in enumerate(rollout_dataset):
            ground_truth_data = rollout_data["position"][:, 1:-1]  # drop the two boundary floes
            num_nodes = ground_truth_data.shape[1]

            for model_noise in model_noise_list:
                for observation_noise in observer_noise_list:
                    enkf_traj = ENKF(simulator, rollout_data, model_noise, observation_noise)

                    if one_ensemble:
                        x_mean = enkf_traj[ensemble_chosen_index, :, :]
                        DATASET_NAME = f'ENKF_{ensemble_chosen_index}'
                    else:
                        x_mean = np.mean(enkf_traj, axis=0)  # (time, nodes)
                        DATASET_NAME = 'ENKF_mean'
                    x_mean = x_mean[:, 1:-1]  # remove boundary floes

                    enkf_RMSE, enkf_RMSE_array = average_RMSE(x_mean, ground_truth_data, total_time_steps)
                    _, correlation_array = average_pattern_correlation(x_mean, ground_truth_data)
                    relative_accuracy_array = relative_accuracy(x_mean, ground_truth_data, total_time_steps)
                    average_accuracy_array = average_accuracy(x_mean, ground_truth_data, total_time_steps)
                    if not PLOT_GRAPH:
                        continue

                    output_root = f'./img/{DATASET_NAME}/{model_noise}_{observation_noise}'

                    ################ RMSE for position (saved once; it used to be saved twice) ########
                    fig, ax = plt.subplots(num=1, clear=True)
                    ax.set_ylabel('RMSE')
                    ax.set_xlabel('Time')
                    ax.plot(enkf_RMSE_array, label='ENKF')
                    ax.legend(loc='upper right')
                    rmse_folder = f'{output_root}/rmse'
                    _save_figure(rmse_folder, f'{i_dataset}_rmse_position_{split}.png')
                    with open(f"{rmse_folder}/Output.txt", "w") as text_file:
                        text_file.write(f'RMSE for position in different time step \n ENKF: {enkf_RMSE} ')

                    ################ ENKF vs. ground-truth trajectory per node ########################
                    # (The identical per-node plots previously also dumped into the rmse folder are dropped.)
                    for node in range(num_nodes):
                        fig, ax = plt.subplots(num=1, clear=True)
                        ax.set_ylabel('Position')
                        ax.set_xlabel('Time')
                        ax.plot(x_mean[:, node], label='ENKF')
                        ax.plot(ground_truth_data[:, node], label='Ground Truth')
                        ax.legend(loc='upper right')
                        _save_figure(f'{output_root}/ENKF_Trajectory',
                                     f'{i_dataset}_ENKF_Trajectory_{split}_{node}.png')

                    ################ PCC for position #################################################
                    fig, ax = plt.subplots(num=1, clear=True)
                    ax.set_ylabel('PCC')
                    ax.set_xlabel('Index of Floes')
                    ax.grid()
                    ax.scatter(np.arange(len(correlation_array)), correlation_array)
                    for idx, value in enumerate(correlation_array):
                        ax.annotate(str(round(value, 3)), (idx, value))
                    _save_figure(f'{output_root}/PCC', f'{i_dataset}_PCC_{split}.png')

                    ################ Relative accuracy per node over time #############################
                    for node in range(num_nodes):
                        fig, ax = plt.subplots(num=1, clear=True)
                        ax.set_ylabel('Relative accuracy')
                        ax.set_xlabel('Time')  # x axis is the time step, not the floe index
                        ax.plot(relative_accuracy_array[:, node], label=f'{node}')
                        ax.legend(loc='upper right')
                        _save_figure(f'{output_root}/Relative_accuracy',
                                     f'{i_dataset}_Relative_{split}_{node}.png')

                    ################ Node-averaged relative accuracy over time #########################
                    fig, ax = plt.subplots(num=1, clear=True)
                    ax.set_ylabel('Relative accuracy')
                    ax.set_xlabel('Time')
                    ax.plot(average_accuracy_array, label='ENKF')
                    ax.legend(loc='upper right')
                    average_folder = f'{output_root}/relative_average'
                    _save_figure(average_folder, f'{i_dataset}_relative_average_{split}.png')
                    with open(f"{average_folder}/Output.txt", "w") as text_file:
                        text_file.write(
                            f'Average relative accuracy over time \n ENKF: {average_accuracy_array.mean()} ')


